In [1]:
#################################################################################
#####   INF889E - Méthodes d'intelligence artificielle en bioinformatique   #####
#####             Classification de VIH par zones géographiques             #####
#################################################################################
#####   Author: Riccardo Z******                                            #####
#####   This program is partly inspired by the work presented in a class    #####
#####   workshop by Dylan Lebatteux.                                        #####
#################################################################################
In [2]:
# Import functions
import re
import joblib
import numpy as np
import pandas as pd
from os import listdir
from random import shuffle
from progressbar import ProgressBar
from Bio import SeqIO, pairwise2
from Bio.motifs import create
from sklearn import svm
from sklearn.preprocessing import MinMaxScaler
from sklearn.feature_selection import RFE
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from mpl_toolkits.mplot3d import Axes3D
from dna_features_viewer import GraphicFeature, GraphicRecord
import matplotlib.pyplot as plt
import seaborn as sns
In [3]:
##############################################
#####        IMPORTANT VARIABLES         #####
##############################################
In [4]:
# Scope of classification: if "ALL", classify by region globaly
# If AFR, ASI, CAM, CRB, EUR, FSU, MEA, NAM, OCE or SAM, classify by country within this chosen region 
scope = "AFR"
# Access path for the FASTA files (one file for each region)
path = "../../../../data/" + "all"
# Name of trained model when saving
model_name = "afr.pkl"
# For sampling purposes: will process max n sequences for each target class
n_samples = 2000
# Classification features as sum (false) of motifs or as frequency (true) of motifs
freq = False
# Elimination step for features selection
step = 5
# Number of features to select
n_features = 100
# Train / Test split ratio
split_raito = 0.8
# Dimensions for graphs (2D or 3D)
n_components = 2
# Set maximum number of incorrect records to analyse at the end
max_incorrect = 10
# Set maximum number of correct records to compute alignment with at the end
max_correct = 1000
# Set the length k of the features based on k-mers
k = 5
In [5]:
##############################################
#####        DATA INITIALISATION         #####
##############################################
print("\n         DATA INITIALISATION         ")
print("=====================================")
         DATA INITIALISATION         
=====================================
In [6]:
# Will contain all the data rows, in the form of biopython seq_records objects
data = []
# Will contain a pair of target class -> number of data rows with this target class
targets = {}
# Process raw record label information into its annotations, then insert it into data
# To update if the label of sequences in the FASTA files changes
def add_record(record, target):
    # Initialiation of the seq_record
    header = record.id.split(".")
    record.id = header[4]
    record.name = header[3]
    record.seq = record.seq.upper()
    record.annotations = {"target": target, "subtype": header[0], "country": header[1]}
    # Add it to the data table and update the target classes dictionary
    targets[target] = targets.get(target, 0) + 1
    data.append(record)
In [7]:
# Properly fills the data table using the above function
if scope == "ALL":
    # Used to show progress
    progress = ProgressBar()
    # If scope is ALL, each filename is the name of each region used as a target class
    for filename in progress(sorted(listdir(path))):
        target = filename.split('.')[0]
        for record in SeqIO.parse(path + "/" + filename, "fasta"):
            add_record(record, target)
    print("")
else:
    # Else, countries are target classes, and the scope region is the filename
    for record in SeqIO.parse(path + "/" + scope + ".fasta", "fasta"):
        target = record.id.split(".")[1]
        add_record(record, target)
In [8]:
# Dipslay data information
print("Data information:")
print("Number of sequences:", sum(targets.values()))
print("Number of targets:", len(targets))
print("Minimum number of instances:", min(targets.values()))
print("Maximum number of instances:", max(targets.values()))

# Dipslay data summary
print("\nData summary:")
for key, value in targets.items(): 
    print("Target:", key, "| Number of sequences:", value)

# Display the first 5 samples
print("\nInformation of the first 5 samples:")
for i in range(5):
    print("ID:", data[i].id, "| Sequence:", data[i].seq[0:50], "| Annotations:", data[i].annotations)
Data information:
Number of sequences: 190454
Number of targets: 50
Minimum number of instances: 1
Maximum number of instances: 36174

Data summary:
Target: GH | Number of sequences: 1860
Target: ZM | Number of sequences: 12384
Target: UG | Number of sequences: 29036
Target: RW | Number of sequences: 3158
Target: ET | Number of sequences: 5214
Target: LR | Number of sequences: 77
Target: DJ | Number of sequences: 97
Target: SN | Number of sequences: 2006
Target: CD | Number of sequences: 2947
Target: KE | Number of sequences: 22982
Target: CM | Number of sequences: 14299
Target: CI | Number of sequences: 1145
Target: GQ | Number of sequences: 342
Target: CF | Number of sequences: 1554
Target: ZW | Number of sequences: 2392
Target: ZA | Number of sequences: 36174
Target: TZ | Number of sequences: 10410
Target: MZ | Number of sequences: 1136
Target: CG | Number of sequences: 200
Target: NG | Number of sequences: 3596
Target: GA | Number of sequences: 477
Target: BW | Number of sequences: 11792
Target: GM | Number of sequences: 72
Target: SO | Number of sequences: 24
Target: TN | Number of sequences: 169
Target: EG | Number of sequences: 39
Target: MW | Number of sequences: 18717
Target: GW | Number of sequences: 1411
Target: BJ | Number of sequences: 471
Target: MA | Number of sequences: 819
Target: AO | Number of sequences: 790
Target: BF | Number of sequences: 823
Target: LY | Number of sequences: 147
Target: ML | Number of sequences: 438
Target: NE | Number of sequences: 285
Target: TD | Number of sequences: 255
Target: BI | Number of sequences: 542
Target: SD | Number of sequences: 48
Target: ER | Number of sequences: 1
Target: DZ | Number of sequences: 473
Target: MG | Number of sequences: 73
Target: SC | Number of sequences: 93
Target: SZ | Number of sequences: 109
Target: TG | Number of sequences: 862
Target: GN | Number of sequences: 146
Target: MR | Number of sequences: 48
Target: CV | Number of sequences: 129
Target: SL | Number of sequences: 179
Target: - | Number of sequences: 1
Target: RE | Number of sequences: 12

Information of the first 5 samples:
ID: AB049811 | Sequence: TGGATGGGCTAATTTACTCCAAGAAAAGACAAGAGATCCTTGATCTGTGG | Annotations: {'target': 'GH', 'subtype': '02_AG', 'country': 'GH'}
ID: AB050905 | Sequence: GAAGAAGGGATAATAATTAGATCTGAGAATCTGACAAACAATGCCAAAAC | Annotations: {'target': 'ZM', 'subtype': 'C', 'country': 'ZM'}
ID: AB050906 | Sequence: GAAGAAGAGATAATAATTAGATCTGAAAATCTGGCAGACAATGTCAAAAC | Annotations: {'target': 'ZM', 'subtype': 'C', 'country': 'ZM'}
ID: AB050907 | Sequence: GAAAAAGACATAATAATTAGATCTGAAAATCTAACAAATAATATCAAAAC | Annotations: {'target': 'ZM', 'subtype': 'C', 'country': 'ZM'}
ID: AB050908 | Sequence: GAAAAAGACATAATAATTAGATCTGAAAATCTAACAAATAATATCAAAAC | Annotations: {'target': 'ZM', 'subtype': 'C', 'country': 'ZM'}
In [9]:
##############################################
#####      TRAIN / TEST DATA SPLIT       #####
##############################################
# Initialise train/test tables that will contain the data
train_data = []
test_data = []
# Initialise train/test dictionaries that will contain the number of instances for each target
test_split = {}
train_split = {}
# Initialise the dictionary with the targets keys and the value 0
test_split = test_split.fromkeys(targets.keys(), 0)
train_split = train_split.fromkeys(targets.keys(), 0)
# Shuffle the data
shuffle(data)
In [10]:
# Iterate through the data
for d in data:
    # Get this records's target class
    target = d.annotations["target"]
    # For sampling purposes: train/test threshold is based on n_samples if there is too much records for this target
    threshold = min(targets[target], n_samples) * split_raito
    # Until threshold for this target is reached, fills train data
    if train_split[target] < threshold: 
        train_data.append(d)
        train_split[target] += 1
    # Then, fills test data (until eventually n_samples are collected)
    elif test_split[target] < n_samples * (1-split_raito): 
        test_data.append(d)
        test_split[target] += 1
# Shuffle the data
shuffle(train_data)
shuffle(test_data)
In [11]:
# Data summary of the train/test split
print("\nTrain/Test split summary:")
for train_key, test_key in zip(train_split.keys(), test_split.keys()):
    print("Target:", train_key, "| Train instances:", train_split[train_key], "| Test instances:", test_split[test_key])
print("\nTotal number of training instances:", len(train_data))
print("Total number of testing instances:", len(test_data))
Train/Test split summary:
Target: GH | Train instances: 1488 | Test instances: 372
Target: ZM | Train instances: 1600 | Test instances: 400
Target: UG | Train instances: 1600 | Test instances: 400
Target: RW | Train instances: 1600 | Test instances: 400
Target: ET | Train instances: 1600 | Test instances: 400
Target: LR | Train instances: 62 | Test instances: 15
Target: DJ | Train instances: 78 | Test instances: 19
Target: SN | Train instances: 1600 | Test instances: 400
Target: CD | Train instances: 1600 | Test instances: 400
Target: KE | Train instances: 1600 | Test instances: 400
Target: CM | Train instances: 1600 | Test instances: 400
Target: CI | Train instances: 916 | Test instances: 229
Target: GQ | Train instances: 274 | Test instances: 68
Target: CF | Train instances: 1244 | Test instances: 310
Target: ZW | Train instances: 1600 | Test instances: 400
Target: ZA | Train instances: 1600 | Test instances: 400
Target: TZ | Train instances: 1600 | Test instances: 400
Target: MZ | Train instances: 909 | Test instances: 227
Target: CG | Train instances: 160 | Test instances: 40
Target: NG | Train instances: 1600 | Test instances: 400
Target: GA | Train instances: 382 | Test instances: 95
Target: BW | Train instances: 1600 | Test instances: 400
Target: GM | Train instances: 58 | Test instances: 14
Target: SO | Train instances: 20 | Test instances: 4
Target: TN | Train instances: 136 | Test instances: 33
Target: EG | Train instances: 32 | Test instances: 7
Target: MW | Train instances: 1600 | Test instances: 400
Target: GW | Train instances: 1129 | Test instances: 282
Target: BJ | Train instances: 377 | Test instances: 94
Target: MA | Train instances: 656 | Test instances: 163
Target: AO | Train instances: 632 | Test instances: 158
Target: BF | Train instances: 659 | Test instances: 164
Target: LY | Train instances: 118 | Test instances: 29
Target: ML | Train instances: 351 | Test instances: 87
Target: NE | Train instances: 228 | Test instances: 57
Target: TD | Train instances: 204 | Test instances: 51
Target: BI | Train instances: 434 | Test instances: 108
Target: SD | Train instances: 39 | Test instances: 9
Target: ER | Train instances: 1 | Test instances: 0
Target: DZ | Train instances: 379 | Test instances: 94
Target: MG | Train instances: 59 | Test instances: 14
Target: SC | Train instances: 75 | Test instances: 18
Target: SZ | Train instances: 88 | Test instances: 21
Target: TG | Train instances: 690 | Test instances: 172
Target: GN | Train instances: 117 | Test instances: 29
Target: MR | Train instances: 39 | Test instances: 9
Target: CV | Train instances: 104 | Test instances: 25
Target: SL | Train instances: 144 | Test instances: 35
Target: - | Train instances: 1 | Test instances: 0
Target: RE | Train instances: 10 | Test instances: 2

Total number of training instances: 34693
Total number of testing instances: 8654
In [12]:
##################################################
#####  FEATURES GENERATION BASED ON K-MERS   #####
##################################################
print("\n         FEATURES GENERATION         ")
print("=====================================")
         FEATURES GENERATION         
=====================================
In [13]:
# Initialize an empty dictionary for the k-mers motifs features
instances = {}
# Used to show progress
progress = ProgressBar()
# Iterate through the training data
for d in train_data:
    # Go through the sequence 
    for i in range(0, len(d.seq) - k + 1, 1):
        # Get the current k-mer motif feature
        feature = str(d.seq[i:i + k])
        # If it contains only the characters "A", "C", "G" or "T", it will be saved
        if re.match('^[ACGT]+$', feature): 
            instances[feature] = 0
    progress.update(len(instances))
    # No need to keep going if motifs dictonary reaches max size
    if len(instances) == 4 ** k:
        break
# Used to show progress
progress.finish()
# Save dictonary keys as biopython motifs object
motifs = create(instances.keys())
# Display the number of features
print("\nNumber of features:", len(motifs.instances), "\n")
| |        #                                       | 1024 Elapsed Time: 0:00:00

Number of features: 1024 

In [14]:
######################################################################
##### GENERATION OF THE FEATURE MATRIX (x) AND TARGET VECTOR (y) #####
######################################################################
In [15]:
# Function to generate feature matrix and target vector
def generateFeatures(data):
    # Initialize the feature matrix
    X = []
    # Initialize the target vector
    y = []
    # Used to show progress
    progress = ProgressBar()
    # Iterate through the data
    for d in progress(data):
        # Generate an empty dictionary
        x = {}
        # Initialize the dictionary with targets as keys and 0 as value
        x = x.fromkeys(motifs.instances, 0)
        # Compute X (features matrix): the number of occurrence of k-mers (with overlaping)
        for i in range(0, len(d.seq) - k + 1, 1):
            feature = d.seq[i:i + k]
            # Attempt to increment the number of occurrences of the current k-mer feature
            try: x[feature] += 1
            # It could fail because the current k-mer is not full ACGT
            except: pass
        # Save the features vector in the features matrix
        X.append(list(x.values()))
        # Save the target class in the target vector
        y.append(d.annotations["target"])
    # Return matrices X and y (feature matrix and target vector)
    return X, y
In [16]:
# Generate train/test feature matrices and target vectors
x_train, y_train = generateFeatures(train_data)
x_test, y_test = generateFeatures(test_data)
100% (34693 of 34693) |##################| Elapsed Time: 0:02:07 Time:  0:02:07
100% (8654 of 8654) |####################| Elapsed Time: 0:00:28 Time:  0:00:28
In [17]:
# Function to generate feature matrix and target vector based on k-mer frequency, not the sum
def generateFreqFeatures(x_sum):
    X = []
    for x in x_sum:
        total = sum(x)
        X.append(list(map((lambda i: i / total), x)))
    return X
In [18]:
# If Freq is ture, then the features matrix are frequency of k-mers, not their sum
if freq:
    x_train = generateFreqFeatures(x_train)
    x_test = generateFreqFeatures(x_test)
In [19]:
##############################################
#####       FEATURES NORMALISATION       #####
##############################################
In [20]:
# Instantiate a MinMaxScaler between 0 and 1
minMaxScaler = MinMaxScaler(feature_range = (0,1))
# Apply a scaling to the train and test set
x_train = minMaxScaler.fit_transform(x_train)
x_test = minMaxScaler.fit_transform(x_test)
In [21]:
##############################################
#####         FEATURES SELECTION         #####
##############################################
print("\n         FEATURES SELECTION          ")
print("=====================================")
         FEATURES SELECTION          
=====================================
In [22]:
# Instantiate a linear model based on svm
model = svm.SVC(C = 1.0, kernel='linear', class_weight = None)
# Instantiate the RFE
rfe = RFE(model, n_features_to_select = n_features, step = step, verbose=True)
# Apply RFE and transform the training matrix
x_train = rfe.fit_transform(x_train, y_train)
# Tranform the test matrix (will be useed later for evaluation purposes)
x_test = rfe.transform(x_test)
Fitting estimator with 1024 features.
Fitting estimator with 1019 features.
Fitting estimator with 1014 features.
Fitting estimator with 1009 features.
Fitting estimator with 1004 features.
Fitting estimator with 999 features.
Fitting estimator with 994 features.
Fitting estimator with 989 features.
Fitting estimator with 984 features.
Fitting estimator with 979 features.
Fitting estimator with 974 features.
Fitting estimator with 969 features.
Fitting estimator with 964 features.
Fitting estimator with 959 features.
Fitting estimator with 954 features.
Fitting estimator with 949 features.
Fitting estimator with 944 features.
Fitting estimator with 939 features.
Fitting estimator with 934 features.
Fitting estimator with 929 features.
Fitting estimator with 924 features.
Fitting estimator with 919 features.
Fitting estimator with 914 features.
Fitting estimator with 909 features.
Fitting estimator with 904 features.
Fitting estimator with 899 features.
Fitting estimator with 894 features.
Fitting estimator with 889 features.
Fitting estimator with 884 features.
Fitting estimator with 879 features.
Fitting estimator with 874 features.
Fitting estimator with 869 features.
Fitting estimator with 864 features.
Fitting estimator with 859 features.
Fitting estimator with 854 features.
Fitting estimator with 849 features.
Fitting estimator with 844 features.
Fitting estimator with 839 features.
Fitting estimator with 834 features.
Fitting estimator with 829 features.
Fitting estimator with 824 features.
Fitting estimator with 819 features.
Fitting estimator with 814 features.
Fitting estimator with 809 features.
Fitting estimator with 804 features.
Fitting estimator with 799 features.
Fitting estimator with 794 features.
Fitting estimator with 789 features.
Fitting estimator with 784 features.
Fitting estimator with 779 features.
Fitting estimator with 774 features.
Fitting estimator with 769 features.
Fitting estimator with 764 features.
Fitting estimator with 759 features.
Fitting estimator with 754 features.
Fitting estimator with 749 features.
Fitting estimator with 744 features.
Fitting estimator with 739 features.
Fitting estimator with 734 features.
Fitting estimator with 729 features.
Fitting estimator with 724 features.
Fitting estimator with 719 features.
Fitting estimator with 714 features.
Fitting estimator with 709 features.
Fitting estimator with 704 features.
Fitting estimator with 699 features.
Fitting estimator with 694 features.
Fitting estimator with 689 features.
Fitting estimator with 684 features.
Fitting estimator with 679 features.
Fitting estimator with 674 features.
Fitting estimator with 669 features.
Fitting estimator with 664 features.
Fitting estimator with 659 features.
Fitting estimator with 654 features.
Fitting estimator with 649 features.
Fitting estimator with 644 features.
Fitting estimator with 639 features.
Fitting estimator with 634 features.
Fitting estimator with 629 features.
Fitting estimator with 624 features.
Fitting estimator with 619 features.
Fitting estimator with 614 features.
Fitting estimator with 609 features.
Fitting estimator with 604 features.
Fitting estimator with 599 features.
Fitting estimator with 594 features.
Fitting estimator with 589 features.
Fitting estimator with 584 features.
Fitting estimator with 579 features.
Fitting estimator with 574 features.
Fitting estimator with 569 features.
Fitting estimator with 564 features.
Fitting estimator with 559 features.
Fitting estimator with 554 features.
Fitting estimator with 549 features.
Fitting estimator with 544 features.
Fitting estimator with 539 features.
Fitting estimator with 534 features.
Fitting estimator with 529 features.
Fitting estimator with 524 features.
Fitting estimator with 519 features.
Fitting estimator with 514 features.
Fitting estimator with 509 features.
Fitting estimator with 504 features.
Fitting estimator with 499 features.
Fitting estimator with 494 features.
Fitting estimator with 489 features.
Fitting estimator with 484 features.
Fitting estimator with 479 features.
Fitting estimator with 474 features.
Fitting estimator with 469 features.
Fitting estimator with 464 features.
Fitting estimator with 459 features.
Fitting estimator with 454 features.
Fitting estimator with 449 features.
Fitting estimator with 444 features.
Fitting estimator with 439 features.
Fitting estimator with 434 features.
Fitting estimator with 429 features.
Fitting estimator with 424 features.
Fitting estimator with 419 features.
Fitting estimator with 414 features.
Fitting estimator with 409 features.
Fitting estimator with 404 features.
Fitting estimator with 399 features.
Fitting estimator with 394 features.
Fitting estimator with 389 features.
Fitting estimator with 384 features.
Fitting estimator with 379 features.
Fitting estimator with 374 features.
Fitting estimator with 369 features.
Fitting estimator with 364 features.
Fitting estimator with 359 features.
Fitting estimator with 354 features.
Fitting estimator with 349 features.
Fitting estimator with 344 features.
Fitting estimator with 339 features.
Fitting estimator with 334 features.
Fitting estimator with 329 features.
Fitting estimator with 324 features.
Fitting estimator with 319 features.
Fitting estimator with 314 features.
Fitting estimator with 309 features.
Fitting estimator with 304 features.
Fitting estimator with 299 features.
Fitting estimator with 294 features.
Fitting estimator with 289 features.
Fitting estimator with 284 features.
Fitting estimator with 279 features.
Fitting estimator with 274 features.
Fitting estimator with 269 features.
Fitting estimator with 264 features.
Fitting estimator with 259 features.
Fitting estimator with 254 features.
Fitting estimator with 249 features.
Fitting estimator with 244 features.
Fitting estimator with 239 features.
Fitting estimator with 234 features.
Fitting estimator with 229 features.
Fitting estimator with 224 features.
Fitting estimator with 219 features.
Fitting estimator with 214 features.
Fitting estimator with 209 features.
Fitting estimator with 204 features.
Fitting estimator with 199 features.
Fitting estimator with 194 features.
Fitting estimator with 189 features.
Fitting estimator with 184 features.
Fitting estimator with 179 features.
Fitting estimator with 174 features.
Fitting estimator with 169 features.
Fitting estimator with 164 features.
Fitting estimator with 159 features.
Fitting estimator with 154 features.
Fitting estimator with 149 features.
Fitting estimator with 144 features.
Fitting estimator with 139 features.
Fitting estimator with 134 features.
Fitting estimator with 129 features.
Fitting estimator with 124 features.
Fitting estimator with 119 features.
Fitting estimator with 114 features.
Fitting estimator with 109 features.
Fitting estimator with 104 features.
In [23]:
# Compute the reduction percentage of the feature matrix
reduction_percentage = ((len(motifs.instances) - n_features) / len(motifs.instances) * 100)
# Print the reduction percentage
print("\nReduction percentage:", round(reduction_percentage, 2), "%")
Reduction percentage: 90.23 %
In [24]:
# Initialize the table that will contain the selected features
instances = []
# Save selected k-mers features
for i, mask in enumerate(rfe.support_): 
    if mask == True: instances.append(motifs.instances[i])
# Save table as biopython motifs object
features = create(instances)
In [25]:
##############################################
#####    TRAINING DATA VISUALISATION     #####
##############################################
print("\n     TRAINING DATA VISUALISATION     ")
print("=====================================")
     TRAINING DATA VISUALISATION     
=====================================
In [30]:
# Define the function to draw Scatter Plot
def generateScatterPlot(title, figure_width, figure_height, data, X, y):
    # If 2d dimensions
    if n_components == 2:
        # Initialize a 2-dimensional figure
        fig, ax = plt.subplots(figsize=(figure_width, figure_height))
    # If 3d dimensions
    else:
        # Initialize a 3-dimensional figure
        fig = plt.figure(figsize=(15, 10))
        ax = Axes3D(fig)
    # List of markers
    markers = ["o","+", "^", "x", "*"]
    # List of colors
    colors = ["tab:blue", "tab:orange", 
              "tab:green", "tab:red", 
              "tab:purple", "tab:brown", 
              "tab:pink", "tab:grey", 
              "tab:olive", "tab:cyan",]
    
    # Iterate through the targets
    for i, target in enumerate(y):
        # Set the list of axis positions
        x = []
        y = []
        z = []
        # If the number of targets is less than 10
        if i < 10:
            color = colors[i]
            marker = markers[0]
        # If the number of targets is less than 20
        elif i < 20:
            color = colors[i-10]
            marker = markers[1]
        # If the number of targets is less than 30
        elif i < 30:
            color = colors[i-20]
            marker = markers[2]
        # If the number of targets is less than 40
        elif i < 40:
            color = colors[i-30]
            marker = markers[2]
        # If the number of targets is less than 50
        else:
            color = colors[i-40]
            marker = markers[4]
            
        # Iterate through the data
        for i, d in enumerate(data):
            # If the sequence belongs to the target of interest
            if d.annotations["target"] == target:
                # Save the value of the positions
                x.append(X[i][0])
                y.append(X[i][1])
                if n_components == 3: z.append(X[i][2])
              
        # Add the current scatter plot to the figure
        if n_components == 2:
            ax.scatter(x, y, c = color, label = target, alpha = 0.75, edgecolors = 'none', marker=marker)
        else:
            ax.scatter(x, y, z, c = color, label=target,alpha=0.75, edgecolors='none', marker=marker)

    # Display the grid
    ax.grid(True)
    # Set the legend parameters
    ax.legend(loc = 2, prop = {'size': 10})
    # Set the tite
    plt.title(title)
    # Set axes labels
    if n_components == 2:
        plt.xlabel('PC1')
        plt.ylabel('PC2')
    else: 
        ax.set_xlabel('PC1')
        ax.set_ylabel('PC2')
        ax.set_zlabel('PC3')
    # Displqy the figure
    plt.show()
In [27]:
# Instantiate a TSNE with 3 principal components
tsne = TSNE(n_components = 3, perplexity = 50, verbose=True)
# Apply TSNE to X_train
x_tsne = tsne.fit_transform(x_train)
[t-SNE] Computing 151 nearest neighbors...
[t-SNE] Indexed 34693 samples in 0.005s...
[t-SNE] Computed neighbors for 34693 samples in 20.871s...
[t-SNE] Computed conditional probabilities for sample 1000 / 34693
[t-SNE] Computed conditional probabilities for sample 2000 / 34693
[t-SNE] Computed conditional probabilities for sample 3000 / 34693
[t-SNE] Computed conditional probabilities for sample 4000 / 34693
[t-SNE] Computed conditional probabilities for sample 5000 / 34693
[t-SNE] Computed conditional probabilities for sample 6000 / 34693
[t-SNE] Computed conditional probabilities for sample 7000 / 34693
[t-SNE] Computed conditional probabilities for sample 8000 / 34693
[t-SNE] Computed conditional probabilities for sample 9000 / 34693
[t-SNE] Computed conditional probabilities for sample 10000 / 34693
[t-SNE] Computed conditional probabilities for sample 11000 / 34693
[t-SNE] Computed conditional probabilities for sample 12000 / 34693
[t-SNE] Computed conditional probabilities for sample 13000 / 34693
[t-SNE] Computed conditional probabilities for sample 14000 / 34693
[t-SNE] Computed conditional probabilities for sample 15000 / 34693
[t-SNE] Computed conditional probabilities for sample 16000 / 34693
[t-SNE] Computed conditional probabilities for sample 17000 / 34693
[t-SNE] Computed conditional probabilities for sample 18000 / 34693
[t-SNE] Computed conditional probabilities for sample 19000 / 34693
[t-SNE] Computed conditional probabilities for sample 20000 / 34693
[t-SNE] Computed conditional probabilities for sample 21000 / 34693
[t-SNE] Computed conditional probabilities for sample 22000 / 34693
[t-SNE] Computed conditional probabilities for sample 23000 / 34693
[t-SNE] Computed conditional probabilities for sample 24000 / 34693
[t-SNE] Computed conditional probabilities for sample 25000 / 34693
[t-SNE] Computed conditional probabilities for sample 26000 / 34693
[t-SNE] Computed conditional probabilities for sample 27000 / 34693
[t-SNE] Computed conditional probabilities for sample 28000 / 34693
[t-SNE] Computed conditional probabilities for sample 29000 / 34693
[t-SNE] Computed conditional probabilities for sample 30000 / 34693
[t-SNE] Computed conditional probabilities for sample 31000 / 34693
[t-SNE] Computed conditional probabilities for sample 32000 / 34693
[t-SNE] Computed conditional probabilities for sample 33000 / 34693
[t-SNE] Computed conditional probabilities for sample 34000 / 34693
[t-SNE] Computed conditional probabilities for sample 34693 / 34693
[t-SNE] Mean sigma: 0.125312
[t-SNE] KL divergence after 250 iterations with early exaggeration: 74.890198
[t-SNE] KL divergence after 1000 iterations: 1.234429
In [31]:
# Generate scatter plot of a TSNE
generateScatterPlot(title= "Scatter plot of a two-dimensional TSNE applied to the training data", 
                    figure_width = 15, 
                    figure_height = 12, 
                    data = train_data, 
                    X = x_tsne, 
                    y = set(y_train))
<ipython-input-30-c0904cf06dd2>:59: UserWarning: You passed a edgecolor/edgecolors ('none') for an unfilled marker ('+').  Matplotlib is ignoring the edgecolor in favor of the facecolor.  This behavior may change in the future.
  ax.scatter(x, y, c = color, label = target, alpha = 0.75, edgecolors = 'none', marker=marker)
2021-05-04T10:54:40.108563 image/svg+xml Matplotlib v3.4.1, https://matplotlib.org/
In [32]:
# Instantiate PCA with 3 principal components
pca = PCA(n_components = 3)
x_pca =  pca.fit_transform(x_train)
In [33]:
# Generate scatter plot of a PCA
generateScatterPlot(title= "Scatter plot of a two-dimensional PCA applied to the training data", 
                    figure_width = 15, 
                    figure_height = 12, 
                    data = train_data, 
                    X = x_pca, 
                    y = set(y_train))
<ipython-input-30-c0904cf06dd2>:59: UserWarning: You passed a edgecolor/edgecolors ('none') for an unfilled marker ('+').  Matplotlib is ignoring the edgecolor in favor of the facecolor.  This behavior may change in the future.
  ax.scatter(x, y, c = color, label = target, alpha = 0.75, edgecolors = 'none', marker=marker)
2021-05-04T10:55:00.059791 image/svg+xml Matplotlib v3.4.1, https://matplotlib.org/
In [34]:
##############################################
#####   MODEL TRAINING AND PREDICTION    #####
##############################################
print("\n    MODEL TRAINING AND PREDICTION    ")
print("=====================================")
    MODEL TRAINING AND PREDICTION    
=====================================
In [35]:
# Fit the model on the train set
model.fit(x_train, y_train)
# Save the model to filename model_name
joblib.dump(model, model_name)
Out[35]:
['afr.pkl']
In [36]:
# Predict with model on the test set
y_pred = model.predict(x_test)
# Display prediction
print("Predictions (" + str(len(y_pred)) + "):", y_pred)
Predictions (8654): ['MZ' 'BW' 'GW' ... 'MZ' 'ZA' 'CF']
In [37]:
##############################################
#####  MODEL PREDICTIONS VISUALISATION   #####
##############################################
print("\n   MODEL PREDICTIONS VISUALISATION   ")
print("=====================================")
   MODEL PREDICTIONS VISUALISATION   
=====================================
In [38]:
# Will contain correct and incorrect data seq_records objects
correct_data = []
incorrect_data = []
# Will contain correct and incorrect features vectors (just like x_test)
correct_features = []
incorrect_features = []
# Iterate through test data
for i, d in enumerate(test_data):
    # Add an annotation to all test data stating its percentage range of ACGT characters
    total_char = len(d.seq)
    total_acgt = 0
    for char in d.seq:
        if re.match('^[ACGT]+$', char):
            total_acgt += 1
    acgt_percent = total_acgt / total_char
    if acgt_percent >= 0.75: d.annotations["acgt-percent"] = "75-100"
    elif acgt_percent >= 0.50: d.annotations["acgt-percent"] = "50-75"
    elif acgt_percent >= 0.25: d.annotations["acgt-percent"] = "25-50"
    else: d.annotations["acgt-percent"] = "0-25"
    # Split test data into correct and incorrect sets depending on prediction results
    if y_pred[i] == d.annotations["target"]:
        correct_data.append(d)
        correct_features.append(x_test[i])
    else:
        # If it's incorrect, add the prediction class as an annotation
        d.annotations["prediction"] = y_pred[i]
        incorrect_data.append(d)
        incorrect_features.append(x_test[i])
In [39]:
# Print the classification_report
print(classification_report(y_test, y_pred, digits = 3))
              precision    recall  f1-score   support

          AO      0.560     0.443     0.495       158
          BF      0.579     0.268     0.367       164
          BI      0.700     0.389     0.500       108
          BJ      0.294     0.053     0.090        94
          BW      0.789     0.682     0.732       400
          CD      0.376     0.522     0.437       400
          CF      0.919     0.768     0.837       310
          CG      0.000     0.000     0.000        40
          CI      0.392     0.616     0.479       229
          CM      0.400     0.445     0.421       400
          CV      1.000     0.120     0.214        25
          DJ      0.000     0.000     0.000        19
          DZ      0.524     0.468     0.494        94
          EG      0.000     0.000     0.000         7
          ET      0.751     0.797     0.773       400
          GA      0.125     0.011     0.019        95
          GH      0.622     0.624     0.623       372
          GM      0.000     0.000     0.000        14
          GN      0.000     0.000     0.000        29
          GQ      0.400     0.029     0.055        68
          GW      0.849     0.755     0.799       282
          KE      0.511     0.573     0.540       400
          LR      0.000     0.000     0.000        15
          LY      1.000     0.483     0.651        29
          MA      0.597     0.681     0.636       163
          MG      0.000     0.000     0.000        14
          ML      0.352     0.517     0.419        87
          MR      0.000     0.000     0.000         9
          MW      0.788     0.585     0.671       400
          MZ      0.429     0.665     0.522       227
          NE      0.338     0.386     0.361        57
          NG      0.502     0.605     0.549       400
          RE      0.000     0.000     0.000         2
          RW      0.699     0.667     0.683       400
          SC      0.000     0.000     0.000        18
          SD      0.000     0.000     0.000         9
          SL      0.000     0.000     0.000        35
          SN      0.486     0.360     0.414       400
          SO      0.000     0.000     0.000         4
          SZ      0.000     0.000     0.000        21
          TD      0.444     0.078     0.133        51
          TG      0.448     0.773     0.567       172
          TN      0.590     0.697     0.639        33
          TZ      0.495     0.485     0.490       400
          UG      0.496     0.682     0.575       400
          ZA      0.540     0.443     0.486       400
          ZM      0.574     0.660     0.614       400
          ZW      0.597     0.797     0.683       400

    accuracy                          0.562      8654
   macro avg      0.399     0.357     0.353      8654
weighted avg      0.559     0.562     0.547      8654

C:\Users\Riccardo\AppData\Roaming\Python\Python39\site-packages\sklearn\metrics\_classification.py:1248: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.
  _warn_prf(average, modifier, msg_start, len(result))
C:\Users\Riccardo\AppData\Roaming\Python\Python39\site-packages\sklearn\metrics\_classification.py:1248: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.
  _warn_prf(average, modifier, msg_start, len(result))
C:\Users\Riccardo\AppData\Roaming\Python\Python39\site-packages\sklearn\metrics\_classification.py:1248: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.
  _warn_prf(average, modifier, msg_start, len(result))
In [40]:
# Dictonaries with pair of annotation -> number of incorrect records with this annotation
subtypes = {}
countries = {}
predictions = {}
acgt_percents = {}
# Iterate through incorrect data
for i in incorrect_data:
    # Increment each kind of annotation with current record values as keys
    subtypes[i.annotations["subtype"]] = subtypes.get(i.annotations["subtype"], 0) + 1
    countries[i.annotations["country"]] = countries.get(i.annotations["country"], 0) + 1
    predictions[i.annotations["prediction"]] = predictions.get(i.annotations["prediction"], 0) + 1
    acgt_percents[i.annotations["acgt-percent"]] = acgt_percents.get(i.annotations["acgt-percent"], 0) + 1
# Display number of incorrect records for each annotation, useful to spot any pattern here
print("Incorrect predictions annotations:")
print("Subtype:", subtypes)
print("Country:", countries)
print("Prediction:", predictions)
print("ACGT percent:", acgt_percents)
Incorrect predictions annotations:
Subtype: {'C': 1173, 'H': 26, '02_AG': 767, 'A1CD': 5, '-': 186, 'A1': 414, 'A': 141, '06_cpx': 80, 'U': 53, 'B': 90, '02A1': 19, 'D': 186, 'K': 4, 'F1': 17, '02G': 15, 'G': 185, '11_cpx': 19, 'O': 16, 'AGHU': 1, '02A3': 8, 'CD': 11, '37_cpx': 5, '0206': 14, 'A2': 9, 'A1C': 32, 'A1U': 5, '02D': 7, '09_cpx': 11, 'A3': 28, 'A1D': 23, 'GH': 1, '02B': 1, 'A1G': 16, 'J': 9, 'M': 5, 'F2': 19, '02F2': 1, 'AC': 4, '01_AE': 33, 'A1H': 1, 'JK': 1, '18_cpx': 6, '26_A5U': 5, '01G': 1, '13_cpx': 6, 'AK': 1, '09A': 1, 'AU': 2, 'AD': 10, '49_cpx': 3, 'A1A2': 4, '02DO': 5, '02U': 4, '45_cpx': 3, 'GK': 1, 'GU': 5, '02A': 4, '22_01A1': 4, 'AJ': 1, 'AG': 11, '11A1': 4, '45C': 1, 'DG': 1, '25_cpx': 3, 'CF1': 2, '0209': 7, 'A1GJ': 1, '02H': 1, 'KU': 2, 'CG': 3, 'BD': 6, '06G': 1, 'AGKU': 1, 'DK': 2, '93_cpx': 1, '01A1': 2, '10_CD': 4, '01U': 1, 'A2D': 2, '0206G': 1, '03_AB': 1, 'A1F1': 1, 'A1A2D': 1, '19_cpx': 2, 'A6': 1, 'A1CG': 1, '09A1': 1, '09G': 1, 'GJ': 2, '16_A2D': 1, '63_02A1': 1, 'A1DG': 1, 'CGU': 1, 'GJO': 1, 'CU': 2, 'A3D': 1, 'F': 1, 'ACD': 1, '92_C2U': 1, '02O': 1, 'A4': 1, 'A1DK': 2, '02A1G': 1, 'A1A6': 1, '05_DF': 1, 'JU': 1}
Country: {'TZ': 206, 'CD': 191, 'GN': 29, 'GH': 140, 'ZM': 136, 'TD': 47, 'UG': 127, 'BF': 120, 'ZW': 81, 'CG': 40, 'CM': 222, 'BW': 127, 'MA': 52, 'ZA': 223, 'MZ': 76, 'SN': 256, 'KE': 171, 'GA': 94, 'BI': 66, 'DJ': 19, 'CI': 88, 'SL': 35, 'RW': 133, 'CV': 22, 'AO': 88, 'DZ': 50, 'LY': 15, 'CF': 72, 'NG': 158, 'TN': 10, 'GW': 69, 'ET': 81, 'GQ': 66, 'BJ': 89, 'MW': 166, 'ML': 42, 'NE': 35, 'GM': 14, 'EG': 7, 'SC': 18, 'MR': 9, 'TG': 39, 'LR': 15, 'SO': 4, 'SZ': 21, 'SD': 9, 'MG': 14, 'RE': 2}
Prediction: {'BW': 73, 'TZ': 198, 'CM': 267, 'UG': 277, 'RW': 115, 'CD': 347, 'CI': 219, 'KE': 219, 'MW': 63, 'MA': 75, 'ZW': 215, 'MZ': 201, 'TG': 164, 'DZ': 40, 'ZM': 196, 'GH': 141, 'ML': 83, 'ZA': 151, 'NG': 240, 'ET': 106, 'CF': 21, 'AO': 55, 'TN': 16, 'GW': 38, 'BI': 18, 'SN': 152, 'NE': 43, 'BF': 32, 'GQ': 3, 'GA': 7, 'BJ': 12, 'TD': 5, 'SL': 1, 'GM': 1}
ACGT percent: {'75-100': 3794}
In [41]:
# Compute the confusion matrix
matrix = confusion_matrix(y_true = y_test, y_pred = y_pred, labels=sorted(targets.keys()))
# Build the heatmap
fig, ax = plt.subplots(figsize=(15, 10))
sns.heatmap(matrix, 
            cmap = 'Blues', 
            annot = True, 
            fmt = ".0f", 
            linewidth = 0.1, 
            xticklabels = sorted(targets.keys()), 
            yticklabels = sorted(targets.keys()))
plt.title("Confusion matrix")
plt.xlabel("Predicted label")
plt.ylabel("True label")
plt.show()
2021-05-04T10:56:34.715291 image/svg+xml Matplotlib v3.4.1, https://matplotlib.org/
In [42]:
# Show percentage of occurence of all features for all target classes in both train and test data
matrix = []
# Iterate through features
for i, feature in enumerate(features.instances):
    # Generate an empty dictionary
    x = {}
    # Initialize the dictionary with targets as keys and 0 as value
    x = x.fromkeys(targets.keys(), 0)
    # Count in all train data
    for f, d in zip(x_train, train_data):
        if f[i] > 0: x[d.annotations["target"]] += 1
    # Count in all test data
    for f, d in zip(x_test, test_data):
        if f[i] > 0: x[d.annotations["target"]] += 1
    # Vector of attendance percentage
    vector = []
    # Iterate through the number of instances and the number of occurrences
    for n_instances, n_occurrences in zip(targets.values(), x.values()):
        n_instances = min(n_instances, n_samples)
        # Compute the percentage of k-mers attendance by target
        attendance_percentage = 100 - ((n_instances - n_occurrences) / n_instances * 100)
        # Save the attendance percentage in the specitic vector
        vector.append(int(attendance_percentage))
    # Save the vector of attendance percentage in the heatmap matrix
    matrix.append(vector)
# Build the heatmap
fig, ax = plt.subplots(figsize=(15, 20))
sns.heatmap(matrix, 
            annot = True, 
            fmt = ".0f", 
            cmap = 'Blues_r',
            linewidth = 0.1, 
            xticklabels = targets.keys(), 
            yticklabels = features.instances)
plt.title("Percentage of presence of k-mers according to HIV subtypes")
plt.xlabel("Target")
plt.ylabel("Features")
plt.show()
2021-05-04T10:56:55.091917 image/svg+xml Matplotlib v3.4.1, https://matplotlib.org/
In [43]:
# For all incorrect records, compute average feature vectors of all correct records for both true and predicted classes
for i_data, i_features in zip(incorrect_data[0:max_incorrect], incorrect_features[0:max_incorrect]):
    # Both matrices to plot
    true_features = []
    pred_features = []
    # Iterate through correct records
    for c_data, c_features in zip(correct_data, correct_features):
        # Compare only if both records are somewhat similar (either same subtype or acgt-percentage range)
        #if i_data.annotations["subtype"] == c_data.annotations["subtype"] or i_data.annotations["acgt-percent"] == c_data.annotations["acgt-percent"]:
        # If this correct record is in the same class as current incorrect record
        if i_data.annotations["target"] == c_data.annotations["target"]:
            true_features.append(c_features)
        # If this correct record is in the class that the current incorrect record has been predicted to  
        if i_data.annotations["prediction"] == c_data.annotations["target"]:
            pred_features.append(c_features)
    # Compute avergare matrices only if similar correct records are found (avoid div per 0)
    if len(true_features) != 0 and len(pred_features) != 0:
        true_features_mean = np.array(true_features).mean(axis=0)
        pred_features_mean = np.array(pred_features).mean(axis=0)
        # Build the heatmap
        fig, ax = plt.subplots(figsize=(40,5))
        sns.heatmap([true_features_mean, i_features, pred_features_mean], 
                #annot = True, 
                #fmt = ".0f", 
                linewidth = 0.1,
                cmap = 'Blues',
                xticklabels = features.instances,
                yticklabels = ["True", "Incorrect", "Prediction"],)
        plt.title("Comparaison of incorrect features vector with true and predicted features vectors averages")
        plt.xlabel("Features")
        plt.show()
2021-05-04T10:57:09.329917 image/svg+xml Matplotlib v3.4.1, https://matplotlib.org/
2021-05-04T10:57:11.825436 image/svg+xml Matplotlib v3.4.1, https://matplotlib.org/
2021-05-04T10:57:13.784536 image/svg+xml Matplotlib v3.4.1, https://matplotlib.org/
2021-05-04T10:57:15.918033 image/svg+xml Matplotlib v3.4.1, https://matplotlib.org/
2021-05-04T10:57:17.799868 image/svg+xml Matplotlib v3.4.1, https://matplotlib.org/
2021-05-04T10:57:19.566010 image/svg+xml Matplotlib v3.4.1, https://matplotlib.org/
2021-05-04T10:57:21.289624 image/svg+xml Matplotlib v3.4.1, https://matplotlib.org/
2021-05-04T10:57:23.191385 image/svg+xml Matplotlib v3.4.1, https://matplotlib.org/
2021-05-04T10:57:25.296631 image/svg+xml Matplotlib v3.4.1, https://matplotlib.org/
In [44]:
# For all incorrect records, compare apparence percentage of all correct records in both true and predicted classes
for i_data, i_features in zip(incorrect_data[0:max_incorrect], incorrect_features[0:max_incorrect]):
    # Dictionaries containing nb of occurences of features for all correct records
    true_features = {}
    pred_features = {}
    true_features = true_features.fromkeys(features.instances, 0)
    pred_features = pred_features.fromkeys(features.instances, 0)
    true_total = 0
    pred_total = 0
    # Iterate through correct records
    for c_data, c_features in zip(correct_data, correct_features):
        # Compare only if both records are somewhat similar (either same subtype or acgt-percentage range)
        #if i_data.annotations["subtype"] == c_data.annotations["subtype"] or i_data.annotations["acgt-percent"] == c_data.annotations["acgt-percent"]:
        # If this correct record is in the same class as current incorrect record
        if i_data.annotations["target"] == c_data.annotations["target"]:
            true_total += 1
            for value, key in zip(c_features, features.instances):
                if value > 0: true_features[key] += 1
        # If this correct record is in the class that the current incorrect record has been predicted to  
        if i_data.annotations["prediction"] == c_data.annotations["target"]:
            pred_total += 1
            for value, key in zip(c_features, features.instances):
                if value > 0: pred_features[key] += 1
    # Compute avergare matrices only if similar correct records are found (avoid div per 0)
    if true_total != 0 and pred_total != 0:
        true_vector = list(map((lambda i: i / true_total), true_features.values()))
        pred_vector = list(map((lambda i: i / pred_total), pred_features.values()))
        # Build the heatmap
        fig, ax = plt.subplots(figsize=(40,5))
        sns.heatmap([true_vector, i_features, pred_vector], 
                #annot = True, 
                #fmt = ".0f", 
                linewidth = 0.1,
                cmap = 'Blues_r',
                xticklabels = features.instances,
                yticklabels = ["True", "Incorrect", "Prediction"],)
        plt.title("Comparaison of incorrect features vector with true and predicted vectors of occurences percents")
        plt.xlabel("Features")
        plt.show()
2021-05-04T10:57:27.292799 image/svg+xml Matplotlib v3.4.1, https://matplotlib.org/
2021-05-04T10:57:29.044258 image/svg+xml Matplotlib v3.4.1, https://matplotlib.org/
2021-05-04T10:57:30.674776 image/svg+xml Matplotlib v3.4.1, https://matplotlib.org/
2021-05-04T10:57:32.311689 image/svg+xml Matplotlib v3.4.1, https://matplotlib.org/
2021-05-04T10:57:33.937692 image/svg+xml Matplotlib v3.4.1, https://matplotlib.org/
2021-05-04T10:57:35.647002 image/svg+xml Matplotlib v3.4.1, https://matplotlib.org/
2021-05-04T10:57:37.493210 image/svg+xml Matplotlib v3.4.1, https://matplotlib.org/
2021-05-04T10:57:39.283526 image/svg+xml Matplotlib v3.4.1, https://matplotlib.org/
2021-05-04T10:57:41.050853 image/svg+xml Matplotlib v3.4.1, https://matplotlib.org/
In [45]:
# Compute alignement of all incorrect records to all correct record and compute avegarge of scores
print("\nComparison of alignement scores between true and predicted class:")
ids = []
matrix = []
# Used to show progress
progress = ProgressBar(max_value=len(incorrect_data[0:max_incorrect])*len(correct_data[0:max_correct])).start()
count = 0
# Shuffle correct data (when we're sampling it)
shuffle(correct_data)
# Iterate through incorrect data
for i in incorrect_data[0:max_incorrect]:
    # Keep different averages for same target class and predicted target class of incorrect record
    true_score_sum = 0
    true_score_nb = 0
    pred_score_sum = 0
    pred_score_nb = 0
    # Iterate through correct data
    for c in correct_data[0:max_correct]:
        # Compare only if both records are somewhat in the same category (both same subtype and acgt-percentage range)
        #if i.annotations["subtype"] == c.annotations["subtype"] and i.annotations["acgt-percent"] == c.annotations["acgt-percent"]:
        # If this correct record is in the same class as current incorrect record
        if i.annotations["target"] == c.annotations["target"]:
            true_score_sum += pairwise2.align.globalxx(i.seq, c.seq, score_only=True)
            true_score_nb += 1
        # If this correct record is in the class that the current incorrect record has been predicted to
        if i.annotations["prediction"] == c.annotations["target"]:
            pred_score_sum += pairwise2.align.globalxx(i.seq, c.seq, score_only=True)
            pred_score_nb += 1
        # Used to show progress
        count += 1
        progress.update(count)
    # Compute avergare only if similar correct records are found (avoid div per 0)
    if true_score_nb != 0 and pred_score_nb != 0:
        ids.append(i.id)
        matrix.append([true_score_sum/true_score_nb, pred_score_sum/pred_score_nb])
# Normalise results
matrix = pd.DataFrame(np.array(matrix))
matrix = matrix.div(matrix.max(axis=1), axis=0)
# Build the heatmap
fig, ax = plt.subplots()
sns.heatmap(matrix, 
            #annot = True, 
            #fmt = ".0f", 
            linewidth = 0.1,
            cmap = 'Blues',
            xticklabels = ["True", "Prediction"], 
            yticklabels = ids)
plt.title("Comparison of alignement scores between true and predicted class")
plt.xlabel("Target")
plt.ylabel("ID")
plt.show()
  0% (4 of 10000) |                      | Elapsed Time: 0:00:00 ETA:   0:04:36
Comparison of alignement scores between true and predicted class:
 99% (9964 of 10000) |################## | Elapsed Time: 0:00:24 ETA:   0:00:00
2021-05-04T10:58:06.776422 image/svg+xml Matplotlib v3.4.1, https://matplotlib.org/
In [46]:
# Tried something, did not work yet...

#features = create(["GGCGG"])
#for i in incorrect_data:
#    graphic_features = []
#    progress = ProgressBar()
#    for pos, seq in progress(features.instances.search(i.seq)):
#        graphic_features.append(GraphicFeature(start = pos, end= pos + k, strand = +1, color= "#ffd700", label=str(seq + "\n" + "Position : " + str(pos))))
#    record = GraphicRecord(sequence_length = len(i.seq), features=graphic_features)
#    record.plot(figure_width = 15)
#    plt.title("Sequence : " + i.id) 
#    plt.show()
#for c in correct_data:
#    if c.annotations["target"] == "CRB":
#        graphic_features = []
#        progress = ProgressBar()
#        for pos, seq in progress(features.instances.search(i.seq)):
#            graphic_features.append(GraphicFeature(start = pos, end= pos + k, strand = +1, color= "#ffd700", label=str(seq + "\n" + "Position : " + str(pos))))
#        record = GraphicRecord(sequence_length = len(i.seq), features=graphic_features)
#        record.plot(figure_width = 15)
#        plt.title("Sequence : " + c.id) 
#        plt.show()
#        break
#for c in correct_data:
#    if c.annotations["target"] == "OCE":
#        graphic_features = []
#        progress = ProgressBar()
#        for pos, seq in progress(features.instances.search(i.seq)):
#            graphic_features.append(GraphicFeature(start = pos, end= pos + k, strand = +1, color= "#ffd700", label=str(seq + "\n" + "Position : " + str(pos))))
#        record = GraphicRecord(sequence_length = len(i.seq), features=graphic_features)
#        record.plot(figure_width = 15)
#        plt.title("Sequence : " + c.id) 
#        plt.show()
#        break